{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 如何使用tf.train.Saver() 把多个模型融合起来\n",
    "下面代码展示了怎么把多个独立训练好的模型联合起来fine-tune。因为比赛中没注意，导致最后有个模型没有保存好。下面的例子中用了两个模型，多个模型同样可以按照下面的方式进行。有几点需要注意：\n",
    "\n",
    "- 在训练单个模型的时候，一定一定一定要把模型的定义放在一个大的 variable_scope 中，这样在多个模型融合的时候才不会出现命名冲突。\n",
    "- 模型融合需要用到的变量，比如最后的一个全连接层，一定要写好接口。\n",
    "- 将模型融合以后，一定一定一定要重新定义一个Saver()，这样才会把融合的大模型中所有的变量存到checkpoint中。\n",
    "\n",
    "下面3部分分别表示：\n",
    "- 1.分别训练单个模型（注意在不同的variable_scope中定义），保存单个模型。\n",
    "- 2.载入多个模型，进行 fine-tune。融合以后定义新的 Saver()， 保存 fine-tune 好的模型。\n",
    "- 3.载入 fine-tune 好的模型。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.训练单个模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../ckpt/saver_test/model1/model.ckpt-1\n",
      "../ckpt/saver_test/model2/model.ckpt-1\n"
     ]
    }
   ],
   "source": [
    "# ** 1.分别训练单个模型（注意在不同的variable_scope中定义），保存单个模型。\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "import wd_1_1_cnn_concat.network as network1 \n",
    "import wd_1_2_cnn_max.network as network2\n",
    "\n",
    "setting1 = network1.Settings()\n",
    "setting2 = network2.Settings()\n",
    "W_embedding = np.random.randn(5, 6)\n",
    "\n",
    "\n",
    "# 下面两个模型\n",
    "with tf.variable_scope('model1') as vs:\n",
    "    model1 = network1.TextCNN(settings=setting1, W_embedding=W_embedding)\n",
    "\n",
    "with tf.variable_scope('model2') as vs:\n",
    "    model2 = network2.TextCNN(settings=setting2, W_embedding=W_embedding)\n",
    "\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "# 保存第1个模型\n",
    "ckpt_path1 = '../ckpt/saver_test/model1/model.ckpt'\n",
    "save_path1 = model1.saver.save(sess, ckpt_path1, global_step=1)\n",
    "print(save_path1)\n",
    "\n",
    "    \n",
    "# 保存第2个模型\n",
    "ckpt_path2 = '../ckpt/saver_test/model2/model.ckpt'\n",
    "save_path2 = model2.saver.save(sess, ckpt_path2, global_step=1)\n",
    "print(save_path2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.载入多个模型\n",
    "**注意，先 restart kernel， 再执行下面的代码**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ../ckpt/saver_test/model1/model.ckpt-1\n",
      "Tensor(\"model1/out_layer/y_pred:0\", shape=(?, 1999), dtype=float32)\n",
      "INFO:tensorflow:Restoring parameters from ../ckpt/saver_test/model2/model.ckpt-1\n",
      "Tensor(\"model2/out_layer/y_pred:0\", shape=(?, 1999), dtype=float32)\n",
      "Tensor(\"new_vars/concat:0\", shape=(?, 3998), dtype=float32)\n",
      "../ckpt/saver_test/model_all/model.ckpt-1\n"
     ]
    }
   ],
   "source": [
    "# ** 2.载入多个模型，进行 fine-tune。融合以后定义新的 Saver()， 保存 fine-tune 好的模型。\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "import wd_1_1_cnn_concat.network as network1 \n",
    "import wd_1_2_cnn_max.network as network2\n",
    "\n",
    "setting1 = network1.Settings()\n",
    "setting2 = network2.Settings()\n",
    "W_embedding = np.random.randn(5, 6)\n",
    "\n",
    "\n",
    "# 下面两个模型\n",
    "with tf.variable_scope('model1') as vs:\n",
    "    model1 = network1.TextCNN(settings=setting1, W_embedding=W_embedding)\n",
    "\n",
    "with tf.variable_scope('model2') as vs:\n",
    "    model2 = network2.TextCNN(settings=setting2, W_embedding=W_embedding)\n",
    "    \n",
    "# 保存第1个模型\n",
    "ckpt_path1 = '../ckpt/saver_test/model1/model.ckpt'\n",
    "model1.saver.restore(sess, ckpt_path1 + '-' + str(1))\n",
    "print(model1.y_pred)\n",
    "    \n",
    "# 保存第2个模型\n",
    "ckpt_path2 = '../ckpt/saver_test/model2/model.ckpt'\n",
    "model2.saver.restore(sess, ckpt_path2 + '-' + str(1))\n",
    "print(model2.y_pred)\n",
    "\n",
    "# 创建新的变量\n",
    "with tf.variable_scope('new_vars') as vs:\n",
    "    y_pred = tf.concat([model1.y_pred, model2.y_pred], axis=1)\n",
    "    new_vars = [v for v in tf.global_variables() if v.name.startswith(vs.name+'/')]\n",
    "# 初始化新的变量\n",
    "sess.run(tf.variables_initializer(new_vars))\n",
    "print(y_pred)\n",
    "\n",
    "saver = tf.train.Saver(max_to_keep=1)\n",
    "ckpt_path_all =  '../ckpt/saver_test/model_all/model.ckpt'\n",
    "save_path_all = saver.save(sess, ckpt_path_all, global_step=1)\n",
    "print(save_path_all)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 载入 fine-tune 好的大模型\n",
    "**注意，先 restart kernel， 再执行下面的代码**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ../ckpt/saver_test/model_all/model.ckpt-1\n",
      "Tensor(\"new_vars/concat:0\", shape=(?, 3998), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# ** 3.载入 fine-tune 好的模型。\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "import wd_1_1_cnn_concat.network as network1 \n",
    "import wd_1_2_cnn_max.network as network2\n",
    "\n",
    "setting1 = network1.Settings()\n",
    "setting2 = network2.Settings()\n",
    "W_embedding = np.random.randn(5, 6)\n",
    "\n",
    "\n",
    "# 下面两个模型\n",
    "with tf.variable_scope('model1') as vs:\n",
    "    model1 = network1.TextCNN(settings=setting1, W_embedding=W_embedding)\n",
    "\n",
    "with tf.variable_scope('model2') as vs:\n",
    "    model2 = network2.TextCNN(settings=setting2, W_embedding=W_embedding)\n",
    "\n",
    "# 创建新的变量\n",
    "with tf.variable_scope('new_vars') as vs:\n",
    "    y_pred = tf.concat([model1.y_pred, model2.y_pred], axis=1)    \n",
    "\n",
    "saver = tf.train.Saver(max_to_keep=1)\n",
    "ckpt_path_all =  '../ckpt/saver_test/model_all/model.ckpt'\n",
    "saver.restore(sess, ckpt_path_all + '-' + str(1))\n",
    "print(y_pred)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
